
Conversation


@sshonTT sshonTT commented Oct 3, 2025

Rebase to upstream

zmelumian972 and others added 30 commits July 18, 2025 12:42
…#9501)

* Refactored jax device handling
* Removed the option to use a CPU jax array for CPU torch tensors, since changing jax devices after the fact will use different APIs
jazpurTT and others added 8 commits October 3, 2025 18:27
…chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.

- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - the previous logic treated values like "0" or "false" as truthy (see the parsing sketch after this list).
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec (reconstructed in the sketch after the cases below).
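
A minimal sketch of the env-var fix from the second bullet; the helper name is assumed, and the PR's actual implementation may differ:

```python
import os

def _env_flag_enabled(name: str, default: bool = False) -> bool:
    # Hypothetical helper illustrating the fix: only explicit truthy
    # spellings enable the flag, so "0" and "false" now disable it
    # instead of being treated as truthy.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

convert_shlo_to_shardy = _env_flag_enabled("CONVERT_SHLO_TO_SHARDY")
```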

The new logic now correctly handles cases that were previously unsupported:

  case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
          -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
          -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 3: mesh_shape=(2,4), partition_spec=(0,None)
          -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]
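
For illustration, here is a Python reconstruction of the computation implied by the three cases above; this is a sketch inferred from the examples, not the PR's actual implementation:

```python
def v2_sharding_params(mesh_shape, partition_spec):
    # One tile dim per tensor dim: the size of the mesh axis it is
    # sharded over, or 1 when the dim is replicated (None).
    dims = [1 if a is None else mesh_shape[a] for a in partition_spec]

    # Mesh axes not named in partition_spec are replicated; if they
    # hold more than one device, fold them into a trailing tile dim.
    used = {a for a in partition_spec if a is not None}
    unused = [i for i in range(len(mesh_shape)) if i not in used]
    replicated = 1
    for i in unused:
        replicated *= mesh_shape[i]
    if replicated > 1:
        dims.append(replicated)

    # Devices are reshaped to the mesh shape, then transposed so the
    # sharded axes come first, in partition_spec order.
    reshape_dims = list(mesh_shape)
    transpose_perm = [a for a in partition_spec if a is not None] + unused
    return dims, reshape_dims, transpose_perm

# Reproduces the three cases above:
assert v2_sharding_params((2, 1, 1, 1), (0, None, None, None)) == \
    ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])
assert v2_sharding_params((2, 1, 1, 1), (0,)) == \
    ([2], [2, 1, 1, 1], [0, 1, 2, 3])
assert v2_sharding_params((2, 4), (0, None)) == ([2, 1, 4], [2, 4], [0, 1])
```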

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <[email protected]>
… PJRT backend

This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code.
Key changes:
* Python API
    * Added custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values); see the usage sketch after this list.
    * Added torch_xla.set_custom_compile_options() utility for setting compile options globally.
    * Added internal binding _XLAC._set_custom_compile_options().
* C++ Runtime
    * Added SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient.
    * PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation.
    * Options are stringified before being passed to XLA for compatibility.
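
A minimal usage sketch of the Python API described above; "xla_example_flag" is a placeholder option name, not a flag verified against XLA:

```python
import torch_xla

def double(x):
    return x * 2

# Per-compile options as a dict of bool/float/int/str values (per
# this PR); the option name here is a placeholder.
compiled = torch_xla.compile(
    double, custom_compile_options={"xla_example_flag": True})

# Or set the same options globally for subsequent compilations:
torch_xla.set_custom_compile_options({"xla_example_flag": True})
```
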
Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
…ation (#7)

This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs.

See pytorch#9541 for the upstream PR discussion and additional context.
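
For reference, a minimal usage sketch of the fixed function, assuming a standard SPMD setup with torch_xla's documented API:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n, 1))

t = torch.randn(8, 8).to(torch_xla.device())
# (0, None): shard dim 0 across the mesh, replicate dim 1 -- one of
# the specs whose visualization this PR fixes.
xs.mark_sharding(t, mesh, (0, None))
visualize_tensor_sharding(t, use_color=False)
```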

* Add some tests and reviewer suggestions. Will update V2 op sharding logic in a later commit.

* New implementation (WIP)

* Fix new implementation

* Fix visualize_tensor_sharding function for V2 shardings

sshonTT commented Oct 6, 2025

@AleksKnezevic I tried this out, but it seems to cause random segfaults. I’ll need to dig in further to figure out the root cause.

@AleksKnezevic

Thanks @sshonTT, then please reopen and merge #9 while we investigate.

Fix for API match

sshonTT commented Oct 6, 2025

Hi @hshahTT, @jazpurTT, @ddilbazTT,

I’ve rebased our branch with the upstream changes and verified it on my side, but I’d appreciate it if you could double-check that everything works correctly.

I’ve also built a wheel for testing, which you can find here:
wh-lb-57:/localdev/sshon/ws/pytorch/pytorch-xla/dist/torch_xla-2.9.0+git86bac8b-cp311-cp311-linux_x86_64.whl

Please let me know if it installs and runs fine on your end. Thanks!

@AleksKnezevic

@sshonTT are you still seeing segfaults?


sshonTT commented Oct 6, 2025

@AleksKnezevic No, I don't see them now.

@AleksKnezevic

That's great. What was causing them previously, any ideas?


sshonTT commented Oct 6, 2025

I couldn’t fully root-cause it, but it turned out to be a system-level issue related to Torch Inductor’s mutex handling: it was failing because a mutex was already held by another process or context, likely left behind by a previous pytest run.

After releasing and reassigning the same IRD machine, the issue disappeared. I also verified on another IRD machine that everything now works correctly.

@AleksKnezevic

awesome, thanks @sshonTT! Do we have a way of running CI with this wheel?


sshonTT commented Oct 6, 2025

I’ve triggered this workflow run to build a wheel. Once it’s ready, I’ll update the torch-xla version in tt-xla and test how it behaves. Other than that, I don’t currently have a concrete way to verify this change.


sshonTT commented Oct 6, 2025

I think we have a build issue starting here. I'll find a way to get past it.

@sshonTT sshonTT force-pushed the sshon/rebase-to-upstream branch 3 times, most recently from 616047b to 626b736 on October 7, 2025 16:15
Change Torch build options to avoid build warnings and errors.
@sshonTT sshonTT force-pushed the sshon/rebase-to-upstream branch from 626b736 to 27f7792 on October 7, 2025 16:27

sshonTT commented Oct 8, 2025

Build succeeded after turning off warnings-as-errors, but there's an error when publishing the wheel. I think it's related to S3 bucket credentials:

[screenshot of the publish error]

@jazpurTT I believe you have experience with S3 buckets, so do you know what's going on and have any suggestions for fixing it?

@sshonTT sshonTT force-pushed the sshon/rebase-to-upstream branch 2 times, most recently from b1ebc54 to 7a014f8 on October 10, 2025 12:57
@sshonTT sshonTT force-pushed the sshon/rebase-to-upstream branch from 7a014f8 to a5240f6 on October 10, 2025 19:19
